import torch
import os


class Config:
    """Model configuration for the AG_NEWS text classification task.

    All settings are plain instance attributes assigned in ``__init__``.
    Paths are derived from the directory two levels above this file.

    Constructing an instance ensures that the checkpoint directory
    (``model_save_dir``) exists; the dataset and log directories are only
    recorded as paths and are not created here.
    """

    def __init__(self) -> None:
        # Dataset-related settings.
        # project_dir is the parent of the directory that contains this file.
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.dataset_dir = os.path.join(self.project_dir, 'data')

        # Model / training-related settings.
        self.lr = 5
        self.batch_size = 16
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.epochs = 3
        self.model_save_dir = os.path.join(self.project_dir, 'checkpoints')
        # exist_ok=True avoids the check-then-create race: if another process
        # creates the directory between an existence check and makedirs(),
        # a plain makedirs() call would raise FileExistsError.
        os.makedirs(self.model_save_dir, exist_ok=True)
        self.log_dir = os.path.join(self.project_dir, 'logs')

        self.num_class = 4
        self.embed_dim = 32


# Shared module-level configuration instance. Because Config.__init__ creates
# the checkpoints directory, importing this module has that side effect.
cfg = Config()
